{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "66014f9c-60b8-4cb4-b5c0-3f387aaf01af",
   "metadata": {},
   "source": [
    "# LSTM+CRF序列标注\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96cf4504-e69b-4b1f-9721-31a2d4beae6a",
   "metadata": {},
   "source": [
    "## 概述\n",
    "\n",
    "序列标注指给定输入序列，给序列中每个Token进行标注标签的过程。序列标注问题通常用于从文本中进行信息抽取，包括分词(Word Segmentation)、词性标注(Position Tagging)、命名实体识别(Named Entity Recognition, NER)等。以命名实体识别为例：\n",
    "\n",
    "| 输入序列 | 清 | 华 | 大 | 学 | 座 | 落 | 于 | 首 | 都 | 北 | 京 |\n",
    "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |\n",
    "|输出标注| B | I | I | I | O | O | O | O | O | B | I |\n",
    "\n",
    "如上表所示，`清华大学` 和 `北京`是地名，需要将其识别，我们对每个输入的单词预测其标签，最后根据标签来识别实体。\n",
    "\n",
    "> 这里使用了一种常见的命名实体识别的标注方法——“BIOE”标注，将一个实体(Entity)的开头标注为B，其他部分标注为I，非实体标注为O。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3b5517e9",
   "metadata": {},
   "source": [
    "## 设置运行环境\n",
    "\n",
    "由于资源限制，需开启性能优化模式，具体设置如下参数：\n",
    "\n",
    " max_device_memory=\"2GB\" : 设置设备可用的最大内存为2GB。\n",
    "\n",
    " mode=mindspore.GRAPH_MODE : 表示在GRAPH_MODE模式中运行。\n",
    "\n",
    " device_target=\"Ascend\" : 表示待运行的目标设备为Ascend。\n",
    "\n",
    " jit_config={\"jit_level\":\"O2\"} : 编译优化级别开启极致性能优化，使用下沉的执行方式。\n",
    "\n",
    " ascend_config={\"precision_mode\":\"allow_mix_precision\"} : 自动混合精度，自动将部分算子的精度降低到float16或bfloat16。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f01741d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore\n",
    "mindspore.set_context(max_device_memory=\"2GB\", mode=mindspore.GRAPH_MODE, device_target=\"Ascend\",  jit_config={\"jit_level\":\"O2\"}, ascend_config={\"precision_mode\":\"allow_mix_precision\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce77717a-465c-44cd-8d4d-37c6f8f132b8",
   "metadata": {},
   "source": [
    "## 条件随机场(Conditional Random Field, CRF)\n",
    "\n",
    "从上文的举例可以看到，对序列进行标注，实际上是对序列中每个Token进行标签预测，可以直接视作简单的多分类问题。但是序列标注不仅仅需要对单个Token进行分类预测，同时相邻Token直接有关联关系。以`清华大学`一词为例:\n",
    "\n",
    "| 输入序列 | 清 | 华 | 大 | 学 | |\n",
    "| --- | --- | --- | --- | --- | --- |\n",
    "| 输出标注 | B | I | I | I | √ |\n",
    "| 输出标注 | O | I | I | I | × |\n",
    "\n",
    "如上表所示，正确的实体中包含的4个Token有依赖关系，I前必须是B或I，而错误输出结果将`清`字标注为O，违背了这一依赖。将命名实体识别视为多分类问题，则每个词的预测概率都是独立的，易产生类似的问题，因此需要引入一种能够学习到此种关联关系的算法来保证预测结果的正确性。而条件随机场是适合此类场景的一种[概率图模型](https://en.wikipedia.org/wiki/Graphical_model)。下面对条件随机场的定义和参数化形式进行简析。\n",
    "\n",
    "> 考虑到序列标注问题的线性序列特点，本案例所述的条件随机场特指线性链条件随机场(Linear Chain CRF)\n",
    "\n",
    "设$x=\\{x_0, ..., x_n\\}$为输入序列，$y=\\{y_0, ..., y_n\\}，y \\in Y$为输出的标注序列，其中$n$为序列的最大长度，$Y$表示$x$对应的所有可能的输出序列集合。则输出序列$y$的概率为：\n",
    "\n",
    "$$\\begin{align}P(y|x) = \\frac{\\exp{(\\text{Score}(x, y)})}{\\sum_{y' \\in Y} \\exp{(\\text{Score}(x, y')})} \\qquad (1)\\end{align}$$\n",
    "\n",
    "设$x_i$, $y_i$为序列的第$i$个Token和对应的标签，则$\\text{Score}$需要能够在计算$x_i$和$y_i$的映射的同时，捕获相邻标签$y_{i-1}$和$y_{i}$之间的关系，因此我们定义两个概率函数：\n",
    "\n",
    "1. 发射概率函数$\\psi_\\text{EMIT}$：表示$x_i \\rightarrow y_i$的概率。\n",
    "2. 转移概率函数$\\psi_\\text{TRANS}$：表示$y_{i-1} \\rightarrow y_i$的概率。\n",
    "\n",
    "则可以得到$\\text{Score}$的计算公式：\n",
    "\n",
    "$$\\begin{align}\\text{Score}(x,y) = \\sum_i \\log \\psi_\\text{EMIT}(x_i \\rightarrow y_i) + \\log \\psi_\\text{TRANS}(y_{i-1} \\rightarrow y_i) \\qquad (2)\\end{align} $$\n",
    "\n",
    "设标签集合为$T$，构造大小为$|T|x|T|$的矩阵$\\textbf{P}$，用于存储标签间的转移概率；由编码层(可以为Dense、LSTM等)输出的隐状态$h$可以直接视作发射概率，此时$\\text{Score}$的计算公式可以转化为：\n",
    "\n",
    "$$\\begin{align}\\text{Score}(x,y) = \\sum_i h_i[y_i] + \\textbf{P}_{y_{i-1}, y_{i}} \\qquad (3)\\end{align}$$\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fccdc0f-9b42-42cd-86ef-ea7157c2e0f8",
   "metadata": {},
   "source": [
    "接下来我们根据上述公式，使用MindSpore来实现CRF的参数化形式。首先实现CRF层的前向训练部分，将CRF和损失函数做合并，选择分类问题常用的负对数似然函数(Negative Log Likelihood, NLL)，则有：\n",
    "\n",
    "$$\\begin{align}\\text{Loss} = -log(P(y|x)) \\qquad (4)\\end{align} $$\n",
    "\n",
    "由公式$(1)$可得，\n",
    "\n",
    "$$\\begin{align}\\text{Loss} = -log(\\frac{\\exp{(\\text{Score}(x, y)})}{\\sum_{y' \\in Y} \\exp{(\\text{Score}(x, y')})}) \\qquad (5)\\end{align} $$\n",
    "\n",
    "$$\\begin{align}= log(\\sum_{y' \\in Y} \\exp{(\\text{Score}(x, y')}) - \\text{Score}(x, y) \\end{align}$$\n",
    "\n",
    "根据公式$(5)$，我们称被减数为Normalizer，减数为Score，分别实现后相减得到最终Loss。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7420278-7285-4d50-a59e-0b906ec1b9d0",
   "metadata": {},
   "source": [
    "### Score计算\n",
    "\n",
    "首先根据公式$(3)$计算正确标签序列所对应的得分，这里需要注意，除了转移概率矩阵$\\textbf{P}$外，还需要维护两个大小为$|T|$的向量，分别作为序列开始和结束时的转移概率。同时我们引入了一个掩码矩阵$mask$，将多个序列打包为一个Batch时填充的值忽略，使得$\\text{Score}$计算仅包含有效的Token。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d09936ad-6b20-4423-9f57-8e14917e61d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_score(emissions, tags, seq_ends, mask, trans, start_trans, end_trans):\n",
    "    # emissions: (seq_length, batch_size, num_tags)\n",
    "    # tags: (seq_length, batch_size)\n",
    "    # mask: (seq_length, batch_size)\n",
    "\n",
    "    seq_length, batch_size = tags.shape\n",
    "    mask = mask.astype(emissions.dtype)\n",
    "\n",
    "    # 将score设置为初始转移概率\n",
    "    # shape: (batch_size,)\n",
    "    score = start_trans[tags[0]]\n",
    "    # score += 第一次发射概率\n",
    "    # shape: (batch_size,)\n",
    "    score += emissions[0, mnp.arange(batch_size), tags[0]]\n",
    "\n",
    "    for i in range(1, seq_length):\n",
    "        # 标签由i-1转移至i的转移概率（当mask == 1时有效）\n",
    "        # shape: (batch_size,)\n",
    "        score += trans[tags[i - 1], tags[i]] * mask[i]\n",
    "\n",
    "        # 预测tags[i]的发射概率（当mask == 1时有效）\n",
    "        # shape: (batch_size,)\n",
    "        score += emissions[i, mnp.arange(batch_size), tags[i]] * mask[i]\n",
    "\n",
    "    # 结束转移\n",
    "    # shape: (batch_size,)\n",
    "    last_tags = tags[seq_ends, mnp.arange(batch_size)]\n",
    "    # score += 结束转移概率\n",
    "    # shape: (batch_size,)\n",
    "    score += end_trans[last_tags]\n",
    "\n",
    "    return score"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50efd4f2-2a19-4dc0-bdd4-d43e1345159d",
   "metadata": {},
   "source": [
    "### Normalizer计算\n",
    "\n",
    "根据公式$(5)$，Normalizer是$x$对应的所有可能的输出序列的Score的对数指数和(Log-Sum-Exp)。此时如果按穷举法进行计算，则需要将每个可能的输出序列Score都计算一遍，共有$|T|^{n}$个结果。这里我们采用动态规划算法，通过复用计算结果来提高效率。\n",
    "\n",
    "假设需要计算从第$0$至第$i$个Token所有可能的输出序列得分$\\text{Score}_{i}$，则可以先计算出从第$0$至第$i-1$个Token所有可能的输出序列得分$\\text{Score}_{i-1}$。因此，Normalizer可以改写为以下形式：\n",
    "\n",
    "$$log(\\sum_{y'_{0,i} \\in Y} \\exp{(\\text{Score}_i})) = log(\\sum_{y'_{0,i-1} \\in Y} \\exp{(\\text{Score}_{i-1} + h_{i} + \\textbf{P}})) \\qquad (6)$$\n",
    "\n",
    "其中$h_i$为第$i$个Token的发射概率，$\\textbf{P}$是转移矩阵。由于发射概率矩阵$h$和转移概率矩阵$\\textbf{P}$独立于$y$的序列路径计算，可以将其提出，可得：\n",
    "\n",
    "$$log(\\sum_{y'_{0,i} \\in Y} \\exp{(\\text{Score}_i})) = log(\\sum_{y'_{0,i-1} \\in Y} \\exp{(\\text{Score}_{i-1}})) + h_{i} + \\textbf{P} \\qquad (7)$$\n",
    "\n",
    "根据公式(7)，Normalizer的实现如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d9a0ef6a-1c3a-400e-9053-e0659e8f9e7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_normalizer(emissions, mask, trans, start_trans, end_trans):\n",
    "    # emissions: (seq_length, batch_size, num_tags)\n",
    "    # mask: (seq_length, batch_size)\n",
    "\n",
    "    seq_length = emissions.shape[0]\n",
    "\n",
    "    # 将score设置为初始转移概率，并加上第一次发射概率\n",
    "    # shape: (batch_size, num_tags)\n",
    "    score = start_trans + emissions[0]\n",
    "\n",
    "    for i in range(1, seq_length):\n",
    "        # 扩展score的维度用于总score的计算\n",
    "        # shape: (batch_size, num_tags, 1)\n",
    "        broadcast_score = score.expand_dims(2)\n",
    "\n",
    "        # 扩展emission的维度用于总score的计算\n",
    "        # shape: (batch_size, 1, num_tags)\n",
    "        broadcast_emissions = emissions[i].expand_dims(1)\n",
    "\n",
    "        # 根据公式(7)，计算score_i\n",
    "        # 此时broadcast_score是由第0个到当前Token所有可能路径\n",
    "        # 对应score的log_sum_exp\n",
    "        # shape: (batch_size, num_tags, num_tags)\n",
    "        next_score = broadcast_score + trans + broadcast_emissions\n",
    "\n",
    "        # 对score_i做log_sum_exp运算，用于下一个Token的score计算\n",
    "        # shape: (batch_size, num_tags)\n",
    "        next_score = ops.logsumexp(next_score, axis=1)\n",
    "\n",
    "        # 当mask == 1时，score才会变化\n",
    "        # shape: (batch_size, num_tags)\n",
    "        score = mnp.where(mask[i].expand_dims(1), next_score, score)\n",
    "\n",
    "    # 最后加结束转移概率\n",
    "    # shape: (batch_size, num_tags)\n",
    "    score += end_trans\n",
    "    # 对所有可能的路径得分求log_sum_exp\n",
    "    # shape: (batch_size,)\n",
    "    return ops.logsumexp(score, axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68f81e2f-2c0c-44cf-ad45-3ca7254d5ad6",
   "metadata": {},
   "source": [
    "### Viterbi算法\n",
    "\n",
    "在完成前向训练部分后，需要实现解码部分。这里我们选择适合求解序列最优路径的[Viterbi算法](https://en.wikipedia.org/wiki/Viterbi_algorithm)。与计算Normalizer类似，使用动态规划求解所有可能的预测序列得分。不同的是在解码时同时需要将第$i$个Token对应的score取值最大的标签保存，供后续使用Viterbi算法求解最优预测序列使用。\n",
    "\n",
    "取得最大概率得分$\\text{Score}$，以及每个Token对应的标签历史$\\text{History}$后，根据Viterbi算法可以得到公式：\n",
    "\n",
    "$$P_{0,i} = max(P_{0, i-1}) + P_{i-1, i}$$\n",
    "\n",
    "从第0个至第$i$个Token对应概率最大的序列，只需要考虑从第0个至第$i-1$个Token对应概率最大的序列，以及从第$i$个至第$i-1$个概率最大的标签即可。因此我们逆序求解每一个概率最大的标签，构成最佳的预测序列。\n",
    "\n",
    "> 由于静态图语法限制，我们将Viterbi算法求解最佳预测序列的部分作为后处理函数，不纳入后续CRF层的实现。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7c286935-74eb-413a-8e47-b8fb5264e6b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def viterbi_decode(emissions, mask, trans, start_trans, end_trans):\n",
    "    # emissions: (seq_length, batch_size, num_tags)\n",
    "    # mask: (seq_length, batch_size)\n",
    "\n",
    "    seq_length = mask.shape[0]\n",
    "\n",
    "    score = start_trans + emissions[0]\n",
    "    history = ()\n",
    "\n",
    "    for i in range(1, seq_length):\n",
    "        broadcast_score = score.expand_dims(2)\n",
    "        broadcast_emission = emissions[i].expand_dims(1)\n",
    "        next_score = broadcast_score + trans + broadcast_emission\n",
    "\n",
    "        # 求当前Token对应score取值最大的标签，并保存\n",
    "        indices = next_score.argmax(axis=1)\n",
    "        history += (indices,)\n",
    "\n",
    "        next_score = next_score.max(axis=1)\n",
    "        score = mnp.where(mask[i].expand_dims(1), next_score, score)\n",
    "\n",
    "    score += end_trans\n",
    "\n",
    "    return score, history\n",
    "\n",
    "def post_decode(score, history, seq_length):\n",
    "    # 使用Score和History计算最佳预测序列\n",
    "    batch_size = seq_length.shape[0]\n",
    "    seq_ends = seq_length - 1\n",
    "    # shape: (batch_size,)\n",
    "    best_tags_list = []\n",
    "\n",
    "    # 依次对一个Batch中每个样例进行解码\n",
    "    for idx in range(batch_size):\n",
    "        # 查找使最后一个Token对应的预测概率最大的标签，\n",
    "        # 并将其添加至最佳预测序列存储的列表中\n",
    "        best_last_tag = score[idx].argmax(axis=0)\n",
    "        best_tags = [int(best_last_tag.asnumpy())]\n",
    "\n",
    "        # 重复查找每个Token对应的预测概率最大的标签，加入列表\n",
    "        for hist in reversed(history[:seq_ends[idx]]):\n",
    "            best_last_tag = hist[idx][best_tags[-1]]\n",
    "            best_tags.append(int(best_last_tag.asnumpy()))\n",
    "\n",
    "        # 将逆序求解的序列标签重置为正序\n",
    "        best_tags.reverse()\n",
    "        best_tags_list.append(best_tags)\n",
    "\n",
    "    return best_tags_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37c81282-bc04-4f39-81c4-6d32acfb271e",
   "metadata": {},
   "source": [
    "### CRF层\n",
    "\n",
    "完成上述前向训练和解码部分的代码后，将其组装完整的CRF层。考虑到输入序列可能存在Padding的情况，CRF的输入需要考虑输入序列的真实长度，因此除发射矩阵和标签外，加入`seq_length`参数传入序列Padding前的长度，并实现生成mask矩阵的`sequence_mask`方法。\n",
    "\n",
    "综合上述代码，使用`nn.Cell`进行封装，最后实现完整的CRF层如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "120a39c7-c89d-4cd3-8a75-da27d2853f0e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[WARNING] ME(97316:281472950050832,MainProcess):2024-09-06-10:50:02.452.307 [mindspore/run_check/_check_version.py:357] MindSpore version 2.2.14 and Ascend AI software package (Ascend Data Center Solution)version 7.0 does not match, the version of software package expect one of ['7.1']. Please refer to the match info on: https://www.mindspore.cn/install\n",
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float64'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:549: UserWarning: The value of the smallest subnormal for <class 'numpy.float32'> type is zero.\n",
      "  setattr(self, word, getattr(machar, word).flat[0])\n",
      "[WARNING] ME(97316:281472950050832,MainProcess):2024-09-06-10:50:04.855.681 [mindspore/run_check/_check_version.py:375] MindSpore version 2.2.14 and \"te\" wheel package version 7.0 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(97316:281472950050832,MainProcess):2024-09-06-10:50:04.861.443 [mindspore/run_check/_check_version.py:382] MindSpore version 2.2.14 and \"hccl\" wheel package version 7.0 does not match. For details, refer to the installation guidelines: https://www.mindspore.cn/install\n",
      "[WARNING] ME(97316:281472950050832,MainProcess):2024-09-06-10:50:04.862.319 [mindspore/run_check/_check_version.py:396] Please pay attention to the above warning, countdown: 3\n",
      "[WARNING] ME(97316:281472950050832,MainProcess):2024-09-06-10:50:05.864.215 [mindspore/run_check/_check_version.py:396] Please pay attention to the above warning, countdown: 2\n",
      "[WARNING] ME(97316:281472950050832,MainProcess):2024-09-06-10:50:06.866.175 [mindspore/run_check/_check_version.py:396] Please pay attention to the above warning, countdown: 1\n"
     ]
    }
   ],
   "source": [
    "import mindspore as ms\n",
    "import mindspore.nn as nn\n",
    "import mindspore.ops as ops\n",
    "import mindspore.numpy as mnp\n",
    "from mindspore.common.initializer import initializer, Uniform\n",
    "\n",
    "def sequence_mask(seq_length, max_length, batch_first=False):\n",
    "    \"\"\"根据序列实际长度和最大长度生成mask矩阵\"\"\"\n",
    "    range_vector = mnp.arange(0, max_length, 1, seq_length.dtype)\n",
    "    result = range_vector < seq_length.view(seq_length.shape + (1,))\n",
    "    if batch_first:\n",
    "        return result.astype(ms.int64)\n",
    "    return result.astype(ms.int64).swapaxes(0, 1)\n",
    "\n",
    "class CRF(nn.Cell):\n",
    "    def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:\n",
    "        if num_tags <= 0:\n",
    "            raise ValueError(f'invalid number of tags: {num_tags}')\n",
    "        super().__init__()\n",
    "        if reduction not in ('none', 'sum', 'mean', 'token_mean'):\n",
    "            raise ValueError(f'invalid reduction: {reduction}')\n",
    "        self.num_tags = num_tags\n",
    "        self.batch_first = batch_first\n",
    "        self.reduction = reduction\n",
    "        self.start_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='start_transitions')\n",
    "        self.end_transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags,)), name='end_transitions')\n",
    "        self.transitions = ms.Parameter(initializer(Uniform(0.1), (num_tags, num_tags)), name='transitions')\n",
    "\n",
    "    def construct(self, emissions, tags=None, seq_length=None):\n",
    "        if tags is None:\n",
    "            return self._decode(emissions, seq_length)\n",
    "        return self._forward(emissions, tags, seq_length)\n",
    "\n",
    "    def _forward(self, emissions, tags=None, seq_length=None):\n",
    "        if self.batch_first:\n",
    "            batch_size, max_length = tags.shape\n",
    "            emissions = emissions.swapaxes(0, 1)\n",
    "            tags = tags.swapaxes(0, 1)\n",
    "        else:\n",
    "            max_length, batch_size = tags.shape\n",
    "\n",
    "        if seq_length is None:\n",
    "            seq_length = mnp.full((batch_size,), max_length, ms.int64)\n",
    "\n",
    "        mask = sequence_mask(seq_length, max_length)\n",
    "\n",
    "        # shape: (batch_size,)\n",
    "        numerator = compute_score(emissions, tags, seq_length-1, mask, self.transitions, self.start_transitions, self.end_transitions)\n",
    "        # shape: (batch_size,)\n",
    "        denominator = compute_normalizer(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)\n",
    "        # shape: (batch_size,)\n",
    "        llh = denominator - numerator\n",
    "\n",
    "        if self.reduction == 'none':\n",
    "            return llh\n",
    "        if self.reduction == 'sum':\n",
    "            return llh.sum()\n",
    "        if self.reduction == 'mean':\n",
    "            return llh.mean()\n",
    "        return llh.sum() / mask.astype(emissions.dtype).sum()\n",
    "\n",
    "    def _decode(self, emissions, seq_length=None):\n",
    "        if self.batch_first:\n",
    "            batch_size, max_length = emissions.shape[:2]\n",
    "            emissions = emissions.swapaxes(0, 1)\n",
    "        else:\n",
    "            batch_size, max_length = emissions.shape[:2]\n",
    "\n",
    "        if seq_length is None:\n",
    "            seq_length = mnp.full((batch_size,), max_length, ms.int64)\n",
    "\n",
    "        mask = sequence_mask(seq_length, max_length)\n",
    "\n",
    "        return viterbi_decode(emissions, mask, self.transitions, self.start_transitions, self.end_transitions)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af6c36c4-82e0-40fa-8d97-03cd6961d594",
   "metadata": {},
   "source": [
    "## BiLSTM+CRF模型\n",
    "\n",
    "在实现CRF后，我们设计一个双向LSTM+CRF的模型来进行命名实体识别任务的训练。模型结构如下：\n",
    "\n",
    "```text\n",
    "nn.Embedding -> nn.LSTM -> nn.Dense -> CRF\n",
    "```\n",
    "\n",
    "其中LSTM提取序列特征，经过Dense层变换获得发射概率矩阵，最后送入CRF层。具体实现如下："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c07555e3-f2a2-4c25-beff-5a78491ab2d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class BiLSTM_CRF(nn.Cell):\n",
    "    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_tags, padding_idx=0):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)\n",
    "        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, bidirectional=True, batch_first=True)\n",
    "        self.hidden2tag = nn.Dense(hidden_dim, num_tags, 'he_uniform')\n",
    "        self.crf = CRF(num_tags, batch_first=True)\n",
    "\n",
    "    def construct(self, inputs, seq_length, tags=None):\n",
    "        embeds = self.embedding(inputs)\n",
    "        outputs, _ = self.lstm(embeds, seq_length=seq_length)\n",
    "        feats = self.hidden2tag(outputs)\n",
    "\n",
    "        crf_outs = self.crf(feats, tags, seq_length)\n",
    "        return crf_outs"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68d4f34a-aaca-46e9-8aa1-2ef894952133",
   "metadata": {},
   "source": [
    "完成模型设计后，我们生成两句例子和对应的标签，并构造词表和标签表。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fa53f535-34bc-49a3-b769-e85d3a184cc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "embedding_dim = 16\n",
    "hidden_dim = 32\n",
    "\n",
    "training_data = [(\n",
    "    \"清 华 大 学 坐 落 于 首 都 北 京\".split(),\n",
    "    \"B I I I O O O O O B I\".split()\n",
    "), (\n",
    "    \"重 庆 是 一 个 魔 幻 城 市\".split(),\n",
    "    \"B I O O O O O O O\".split()\n",
    ")]\n",
    "\n",
    "word_to_idx = {}\n",
    "word_to_idx['<pad>'] = 0\n",
    "for sentence, tags in training_data:\n",
    "    for word in sentence:\n",
    "        if word not in word_to_idx:\n",
    "            word_to_idx[word] = len(word_to_idx)\n",
    "\n",
    "tag_to_idx = {\"B\": 0, \"I\": 1, \"O\": 2}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "5c255358-938b-4e88-a810-2fe4e616a044",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "21"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(word_to_idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7130410c-a2ce-4086-8f79-17580197c2d7",
   "metadata": {},
   "source": [
    "接下来实例化模型\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e6d3f2bb-0ea7-457c-8a74-a8e0ea78a109",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = BiLSTM_CRF(len(word_to_idx), embedding_dim, hidden_dim, len(tag_to_idx))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "babd02a3-fb1d-4b9f-8b75-127268ba5e2e",
   "metadata": {},
   "source": [
    "将生成的数据打包成Batch，按照序列最大长度，对长度不足的序列进行填充，分别返回输入序列、输出标签和序列长度构成的Tensor。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7899b6e6-95a4-4ffe-ba49-b27cebe8b306",
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_sequence(seqs, word_to_idx, tag_to_idx):\n",
    "    seq_outputs, label_outputs, seq_length = [], [], []\n",
    "    max_len = max([len(i[0]) for i in seqs])\n",
    "\n",
    "    for seq, tag in seqs:\n",
    "        seq_length.append(len(seq))\n",
    "        idxs = [word_to_idx[w] for w in seq]\n",
    "        labels = [tag_to_idx[t] for t in tag]\n",
    "        idxs.extend([word_to_idx['<pad>'] for i in range(max_len - len(seq))])\n",
    "        labels.extend([tag_to_idx['O'] for i in range(max_len - len(seq))])\n",
    "        seq_outputs.append(idxs)\n",
    "        label_outputs.append(labels)\n",
    "\n",
    "    return ms.Tensor(seq_outputs, ms.int64), \\\n",
    "            ms.Tensor(label_outputs, ms.int64), \\\n",
    "            ms.Tensor(seq_length, ms.int64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0e984dcb-6f89-4520-940f-620cdb997529",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((2, 11), (2, 11), (2,))"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data, label, seq_length = prepare_sequence(training_data, word_to_idx, tag_to_idx)\n",
    "data.shape, label.shape, seq_length.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "473357b0-d008-455a-8190-4db02963020e",
   "metadata": {},
   "source": [
    "最后我们来观察训练500个step后的模型效果(加载训练好的ckpt），首先使用模型预测可能的路径得分以及候选序列。（请忽略[ERROR][WARNING]）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65304086-de03-4edc-b3a2-3ca5e43c4795",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mindspore import load_param_into_net, load_checkpoint\n",
    "from download import download\n",
    "\n",
    "# download ckpt\n",
    "lstm_crf_url = \"https://modelers.cn/coderepo/web/v1/file/MindSpore-Lab/cluoud_obs/main/media/examples/mindspore-courses/orange-pi-online-infer/08-LSTM%2BCRF/lstm-crf.ckpt\"\n",
    "path = \"./lstm-crf.ckpt\"\n",
    "ckpt_path = download(lstm_crf_url, path, replace=True)\n",
    "\n",
    "load_checkpoint=load_checkpoint(ckpt_path)\n",
    "load_param_into_net(model, load_checkpoint)\n",
    "score, history = model(data, seq_length)\n",
    "score"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5dc4378c-d91e-4d19-b3e8-f1bf511369f3",
   "metadata": {},
   "source": [
    "使用后处理函数进行预测得分的后处理。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "ed5ff78a-099c-41f8-ac1e-8f9fbfab7659",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[0, 1, 1, 1, 2, 2, 2, 2, 2, 0, 1], [0, 1, 2, 2, 2, 2, 2, 2, 2]]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predict = post_decode(score, history, seq_length)\n",
    "predict"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2bd6b64-c0e6-41fd-8abc-3a057e28ba57",
   "metadata": {},
   "source": [
    "最后将预测的index序列转换为标签序列，打印输出结果，查看效果。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "df505983-9c1f-4e09-8a67-1cd78ce69fd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "idx_to_tag = {idx: tag for tag, idx in tag_to_idx.items()}\n",
    "\n",
    "def sequence_to_tag(sequences, idx_to_tag):\n",
    "    outputs = []\n",
    "    for seq in sequences:\n",
    "        outputs.append([idx_to_tag[i] for i in seq])\n",
    "    return outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "e30d43c7-7c09-445e-94a4-33fccfaeaf11",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['B', 'I', 'I', 'I', 'O', 'O', 'O', 'O', 'O', 'B', 'I'],\n",
       " ['B', 'I', 'O', 'O', 'O', 'O', 'O', 'O', 'O']]"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sequence_to_tag(predict, idx_to_tag)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  },
  "vscode": {
   "interpreter": {
    "hash": "8c9da313289c39257cb28b126d2dadd33153d4da4d524f730c81a4aaccbd2ca7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
